/*

 * Licensed to the Apache Software Foundation (ASF) under one

 * or more contributor license agreements.  See the NOTICE file

 * distributed with this work for additional information

 * regarding copyright ownership.  The ASF licenses this file

 * to you under the Apache License, Version 2.0 (the

 * "License"); you may not use this file except in compliance

 * with the License.  You may obtain a copy of the License at

 *

 *     http://www.apache.org/licenses/LICENSE-2.0

 *

 * Unless required by applicable law or agreed to in writing, software

 * distributed under the License is distributed on an "AS IS" BASIS,

 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

 * See the License for the specific language governing permissions and

 * limitations under the License.

 */

package com.bff.gaia.unified.fn.harness;



import com.google.auto.service.AutoService;

import com.bff.gaia.unified.fn.harness.control.BundleSplitListener;

import com.bff.gaia.unified.fn.harness.data.UnifiedFnDataClient;

import com.bff.gaia.unified.fn.harness.data.PCollectionConsumerRegistry;

import com.bff.gaia.unified.fn.harness.data.PTransformFunctionRegistry;

import com.bff.gaia.unified.fn.harness.state.UnifiedFnStateClient;

import com.bff.gaia.unified.model.pipeline.v1.RunnerApi;

import com.bff.gaia.unified.model.pipeline.v1.RunnerApi.CombinePayload;

import com.bff.gaia.unified.model.pipeline.v1.RunnerApi.PCollection;

import com.bff.gaia.unified.model.pipeline.v1.RunnerApi.PTransform;

import com.bff.gaia.unified.model.pipeline.v1.RunnerApi.StandardPTransforms;

import com.bff.gaia.unified.runners.core.construction.UnifiedUrns;

import com.bff.gaia.unified.runners.core.construction.RehydratedComponents;

import com.bff.gaia.unified.sdk.coders.Coder;

import com.bff.gaia.unified.sdk.coders.KvCoder;

import com.bff.gaia.unified.sdk.fn.data.FnDataReceiver;

import com.bff.gaia.unified.sdk.function.ThrowingFunction;

import com.bff.gaia.unified.sdk.options.PipelineOptions;

import com.bff.gaia.unified.sdk.transforms.Combine.CombineFn;

import com.bff.gaia.unified.sdk.util.SerializableUtils;

import com.bff.gaia.unified.sdk.util.WindowedValue;

import com.bff.gaia.unified.sdk.util.WindowedValue.WindowedValueCoder;

import com.bff.gaia.unified.sdk.values.KV;

import com.bff.gaia.unified.vendor.guava.com.google.common.annotations.VisibleForTesting;

import com.bff.gaia.unified.vendor.guava.com.google.common.collect.ImmutableMap;

import com.bff.gaia.unified.vendor.guava.com.google.common.collect.Iterables;



import java.io.IOException;

import java.util.Map;

import java.util.function.Supplier;



/** Executes different components of Combine PTransforms. */

public class CombineRunners {

  /** A registrar which provides a factory to handle combine component PTransforms. */
  @AutoService(PTransformRunnerFactory.Registrar.class)
  public static class Registrar implements PTransformRunnerFactory.Registrar {

    /**
     * Maps each combine component URN (precombine, merge accumulators, extract outputs, combine
     * grouped values) to the factory that builds its runner.
     */
    @Override
    public Map<String, PTransformRunnerFactory> getPTransformRunnerFactories() {
      return ImmutableMap.of(
          UnifiedUrns.getUrn(StandardPTransforms.CombineComponents.COMBINE_PER_KEY_PRECOMBINE),
          new PrecombineFactory(),
          UnifiedUrns.getUrn(StandardPTransforms.CombineComponents.COMBINE_PER_KEY_MERGE_ACCUMULATORS),
          MapFnRunners.forValueMapFnFactory(CombineRunners::createMergeAccumulatorsMapFunction),
          UnifiedUrns.getUrn(StandardPTransforms.CombineComponents.COMBINE_PER_KEY_EXTRACT_OUTPUTS),
          MapFnRunners.forValueMapFnFactory(CombineRunners::createExtractOutputsMapFunction),
          UnifiedUrns.getUrn(StandardPTransforms.CombineComponents.COMBINE_GROUPED_VALUES),
          MapFnRunners.forValueMapFnFactory(CombineRunners::createCombineGroupedValuesMapFunction));
    }
  }

  /**
   * Deserializes the Java {@link CombineFn} stored in the spec payload of a {@link CombinePayload}.
   *
   * <p>Shared by every combine component so the payload is decoded identically everywhere.
   *
   * @param combinePayload the parsed payload of a combine component transform
   * @return the deserialized combine function, cast to the caller's expected type parameters
   */
  @SuppressWarnings("unchecked")
  private static <InputT, AccumT, OutputT>
      CombineFn<InputT, AccumT, OutputT> deserializeCombineFn(CombinePayload combinePayload) {
    return (CombineFn<InputT, AccumT, OutputT>)
        SerializableUtils.deserializeFromByteArray(
            combinePayload.getCombineFn().getSpec().getPayload().toByteArray(), "CombineFn");
  }

  /**
   * Runs the precombine phase of a combine-per-key.
   *
   * <p>Input elements are buffered in a grouping table built from the {@link CombineFn}. Whatever
   * the table emits (windowed {@code KV<KeyT, AccumT>} values) is forwarded to the downstream
   * receiver.
   */
  private static class PrecombineRunner<KeyT, InputT, AccumT> {

    /** Sample rate handed to the grouping table's size estimator. */
    private static final double SIZE_ESTIMATOR_SAMPLE_RATE = 0.001;

    private final PipelineOptions options;

    private final CombineFn<InputT, AccumT, ?> combineFn;

    private final FnDataReceiver<WindowedValue<KV<KeyT, AccumT>>> output;

    private final Coder<KeyT> keyCoder;

    private final Coder<AccumT> accumCoder;

    /** Rebuilt in {@link #startBundle()}; {@code null} until the first bundle starts. */
    private GroupingTable<WindowedValue<KeyT>, InputT, AccumT> groupingTable;

    PrecombineRunner(
        PipelineOptions options,
        CombineFn<InputT, AccumT, ?> combineFn,
        FnDataReceiver<WindowedValue<KV<KeyT, AccumT>>> output,
        Coder<KeyT> keyCoder,
        Coder<AccumT> accumCoder) {
      this.options = options;
      this.combineFn = combineFn;
      this.output = output;
      this.keyCoder = keyCoder;
      this.accumCoder = accumCoder;
    }

    /** Creates a fresh grouping table for the bundle that is about to start. */
    void startBundle() {
      groupingTable =
          PrecombineGroupingTable.combiningAndSampling(
              options, combineFn, keyCoder, accumCoder, SIZE_ESTIMATOR_SAMPLE_RATE);
    }

    /** Adds one input element to the grouping table; the table may emit output immediately. */
    void processElement(WindowedValue<KV<KeyT, InputT>> elem) throws Exception {
      groupingTable.put(elem, this::emit);
    }

    /** Flushes everything still buffered in the grouping table downstream. */
    void finishBundle() throws Exception {
      groupingTable.flush(this::emit);
    }

    /**
     * Forwards one element produced by the grouping table to the downstream receiver.
     *
     * <p>The table hands back untyped objects; they are windowed {@code KV<KeyT, AccumT>} by
     * construction of this runner.
     */
    @SuppressWarnings("unchecked")
    private void emit(Object outputElem) throws Exception {
      output.accept((WindowedValue<KV<KeyT, AccumT>>) outputElem);
    }
  }

  /** A factory for {@link PrecombineRunner}s. */
  @VisibleForTesting
  public static class PrecombineFactory<KeyT, InputT, AccumT>
      implements PTransformRunnerFactory<PrecombineRunner<KeyT, InputT, AccumT>> {

    /**
     * Builds a {@link PrecombineRunner} for the given transform.
     *
     * <p>It registers the runner's start-bundle, per-element, and finish-bundle handlers with the
     * supplied registries.
     *
     * @throws IOException if the combine payload cannot be parsed or a coder cannot be rehydrated
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public PrecombineRunner<KeyT, InputT, AccumT> createRunnerForPTransform(
        PipelineOptions pipelineOptions,
        UnifiedFnDataClient unifiedFnDataClient,
        UnifiedFnStateClient unifiedFnStateClient,
        String pTransformId,
        PTransform pTransform,
        Supplier<String> processBundleInstructionId,
        Map<String, PCollection> pCollections,
        Map<String, RunnerApi.Coder> coders,
        Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
        PCollectionConsumerRegistry pCollectionConsumerRegistry,
        PTransformFunctionRegistry startFunctionRegistry,
        PTransformFunctionRegistry finishFunctionRegistry,
        BundleSplitListener splitListener)
        throws IOException {
      // Get objects needed to create the runner.
      RehydratedComponents rehydratedComponents =
          RehydratedComponents.forComponents(
              RunnerApi.Components.newBuilder()
                  .putAllCoders(coders)
                  .putAllWindowingStrategies(windowingStrategies)
                  .build());
      String mainInputTag = Iterables.getOnlyElement(pTransform.getInputsMap().keySet());
      String mainInputPCollectionId = pTransform.getInputsOrThrow(mainInputTag);
      RunnerApi.PCollection mainInput = pCollections.get(mainInputPCollectionId);

      // Input coder may sometimes be WindowedValueCoder depending on runner, instead of the
      // expected KvCoder; unwrap it (looking the coder up only once) before reading the key coder.
      Coder<?> uncastInputCoder = rehydratedComponents.getCoder(mainInput.getCoderId());
      Coder<?> unwrappedInputCoder =
          uncastInputCoder instanceof WindowedValueCoder
              ? ((WindowedValueCoder<?>) uncastInputCoder).getValueCoder()
              : uncastInputCoder;
      KvCoder<KeyT, InputT> inputCoder = (KvCoder<KeyT, InputT>) unwrappedInputCoder;
      Coder<KeyT> keyCoder = inputCoder.getKeyCoder();

      CombinePayload combinePayload = CombinePayload.parseFrom(pTransform.getSpec().getPayload());
      CombineFn<InputT, AccumT, ?> combineFn = deserializeCombineFn(combinePayload);
      Coder<AccumT> accumCoder =
          (Coder<AccumT>) rehydratedComponents.getCoder(combinePayload.getAccumulatorCoderId());

      FnDataReceiver<WindowedValue<KV<KeyT, AccumT>>> consumer =
          (FnDataReceiver)
              pCollectionConsumerRegistry.getMultiplexingConsumer(
                  Iterables.getOnlyElement(pTransform.getOutputsMap().values()));

      PrecombineRunner<KeyT, InputT, AccumT> runner =
          new PrecombineRunner<>(pipelineOptions, combineFn, consumer, keyCoder, accumCoder);

      // Register the appropriate handlers.
      startFunctionRegistry.register(pTransformId, runner::startBundle);
      pCollectionConsumerRegistry.register(
          mainInputPCollectionId,
          pTransformId,
          (FnDataReceiver)
              (FnDataReceiver<WindowedValue<KV<KeyT, InputT>>>) runner::processElement);
      finishFunctionRegistry.register(pTransformId, runner::finishBundle);

      return runner;
    }
  }

  /**
   * Creates the map function for the merge-accumulators component.
   *
   * <p>It merges all accumulators grouped under a key into a single accumulator.
   *
   * @param pTransformId unused here; accepted so the method matches the factory function expected
   *     by {@code MapFnRunners.forValueMapFnFactory}
   * @param pTransform the transform whose spec payload is a {@link CombinePayload}
   * @throws IOException if the combine payload cannot be parsed
   */
  static <KeyT, AccumT>
      ThrowingFunction<KV<KeyT, Iterable<AccumT>>, KV<KeyT, AccumT>>
          createMergeAccumulatorsMapFunction(String pTransformId, PTransform pTransform)
              throws IOException {
    CombinePayload combinePayload = CombinePayload.parseFrom(pTransform.getSpec().getPayload());
    CombineFn<?, AccumT, ?> combineFn = deserializeCombineFn(combinePayload);

    return (KV<KeyT, Iterable<AccumT>> input) ->
        KV.of(input.getKey(), combineFn.mergeAccumulators(input.getValue()));
  }

  /**
   * Creates the map function for the extract-outputs component.
   *
   * <p>It converts each key's accumulator into the final output value.
   *
   * @param pTransformId unused here; accepted so the method matches the factory function expected
   *     by {@code MapFnRunners.forValueMapFnFactory}
   * @param pTransform the transform whose spec payload is a {@link CombinePayload}
   * @throws IOException if the combine payload cannot be parsed
   */
  static <KeyT, AccumT, OutputT>
      ThrowingFunction<KV<KeyT, AccumT>, KV<KeyT, OutputT>> createExtractOutputsMapFunction(
          String pTransformId, PTransform pTransform) throws IOException {
    CombinePayload combinePayload = CombinePayload.parseFrom(pTransform.getSpec().getPayload());
    CombineFn<?, AccumT, OutputT> combineFn = deserializeCombineFn(combinePayload);

    return (KV<KeyT, AccumT> input) ->
        KV.of(input.getKey(), combineFn.extractOutput(input.getValue()));
  }

  /**
   * Creates the map function for the combine-grouped-values component.
   *
   * <p>It applies the whole {@link CombineFn} to the values grouped under each key.
   *
   * @param pTransformId unused here; accepted so the method matches the factory function expected
   *     by {@code MapFnRunners.forValueMapFnFactory}
   * @param pTransform the transform whose spec payload is a {@link CombinePayload}
   * @throws IOException if the combine payload cannot be parsed
   */
  static <KeyT, InputT, AccumT, OutputT>
      ThrowingFunction<KV<KeyT, Iterable<InputT>>, KV<KeyT, OutputT>>
          createCombineGroupedValuesMapFunction(String pTransformId, PTransform pTransform)
              throws IOException {
    CombinePayload combinePayload = CombinePayload.parseFrom(pTransform.getSpec().getPayload());
    CombineFn<InputT, AccumT, OutputT> combineFn = deserializeCombineFn(combinePayload);

    return (KV<KeyT, Iterable<InputT>> input) ->
        KV.of(input.getKey(), combineFn.apply(input.getValue()));
  }
}